{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8vMo9TxNBl9t"
      },
      "source": [
        "# CLIP ImageNet Zero-Shot Performance\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "This is a colab for clip's zero-shot ImageNet robustness evaluation\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DD02wux82iLo"
      },
      "source": [
        "## 1. Reproducing CLIP zero-shot ImageNet classification performance\n",
        "\n",
        "Replicating the latest OpenAI [Colab](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qUAPAyyJIX61"
      },
      "outputs": [],
      "source": [
        "# This section is built based on \n",
        "\n",
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZrHcKMdltCnX"
      },
      "outputs": [],
      "source": [
        "%pylab inline\n",
        "import tensorflow_datasets as tfds\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from tqdm import tqdm\n",
        "import tensorflow as tf\n",
        "import random\n",
        "from collections import Counter\n",
        "\n",
        "from scenic.projects.baselines.clip import model as clip\n",
        "from scenic.projects.baselines.clip import tokenizer as clip_tokenizer\n",
        "from gvt.libml import preprocess as gvt_preprocess"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "05NWZ85TVscl"
      },
      "outputs": [],
      "source": [
        "#@title ImageNet classNames\n",
        "clip.IMAGENET_CLASSES = multimodal_utils._ZEROSHOT_CLASS_NAMES['imagenet']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JIzYXXxzTAzZ"
      },
      "outputs": [],
      "source": [
        "model_name = 'vit_b16'\n",
        "\n",
        "model = clip.MODELS[model_name]()\n",
        "vars = clip.load_model_vars(model_name)\n",
        "\n",
        "encode_text = jax.jit(lambda texts: model.apply(vars, texts, method=model.encode_text))\n",
        "encode_image = jax.jit(lambda x: model.apply(vars, x, method=model.encode_image))\n",
        "\n",
        "tokenize_fn = clip_tokenizer.build_tokenizer()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lRjGzv6sosYZ"
      },
      "outputs": [],
      "source": [
        "def permute_words(text):\n",
        "  words = text.split(' ')\n",
        "  random.shuffle(words)\n",
        "  return ' '.join(words)\n",
        "\n",
        "def zeroshot_classifier(classnames, templates, permute=False):\n",
        "  zeroshot_weights = []\n",
        "  permute_fn = permute_words if permute else lambda x: x\n",
        "  for classname in tqdm(classnames):\n",
        "    texts = [permute_fn(template.format(classname)) for template in templates]\n",
        "    class_embeddings = encode_text(tokenize_fn(texts))\n",
        "    class_embedding = class_embeddings.mean(0)\n",
        "    class_embedding /= jnp.linalg.norm(class_embedding)\n",
        "    zeroshot_weights.append(class_embedding)\n",
        "  return jnp.stack(zeroshot_weights, axis=1)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nw6OUXuMpQ3Q"
      },
      "outputs": [],
      "source": [
        "# Readout weights with prompt engineering\n",
        "weights_prompteng = zeroshot_classifier(clip.IMAGENET_CLASSES, clip.PROMPTS)\n",
        "\n",
        "# Readout weights with modified ImageNet class names only\n",
        "weights_name = zeroshot_classifier(clip.IMAGENET_CLASSES, ['{}'])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HLxQkmIIaeH2"
      },
      "outputs": [],
      "source": [
        "def preprocess(batch, size=224):\n",
        "  batch = tf.image.convert_image_dtype(batch, dtype=tf.float32)\n",
        "  return gvt_preprocess.central_crop(gvt_preprocess.resize_small(batch, size, size), size, size)\n",
        "\n",
        "def normalize(img):\n",
        "  return (img - clip.IMAGE_MEAN) / clip.IMAGE_STD\n",
        "  \n",
        "def unnormalize(x):\n",
        "  return x * clip.IMAGE_STD + clip.IMAGE_MEAN"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VZx-2xM2uflF"
      },
      "outputs": [],
      "source": [
        "def load_dataset(dataset='imagenet2012', split='validation', batch_size=1024):\n",
        "  ds = tfds.load(dataset, split=split)\n",
        "  def _preprocess(d):\n",
        "    d['image'] = normalize(preprocess(d['image']))\n",
        "    return d\n",
        "  def _prepare(d):\n",
        "    return jax.tree.map(lambda x: x._numpy(), d)\n",
        "  batched_dataset = ds.map(_preprocess).batch(batch_size)\n",
        "  batched_dataset = map(_prepare, batched_dataset)\n",
        "  return batched_dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nlfoPT44NafG"
      },
      "outputs": [],
      "source": [
        "def compute_image_embeddings(dset):\n",
        "  embeddings = []\n",
        "  labels = []\n",
        "  for batch in tqdm(dset):\n",
        "    embeddings.append(encode_image(batch['image']))\n",
        "    labels.append(batch['label'])\n",
        "  return jnp.vstack(embeddings), jnp.hstack(labels)\n",
        "  \n",
        "def compute_accuracy(logits, labels):\n",
        "  top_probs, top_labels = jax.lax.top_k(logits, 5)\n",
        "  top1 = 100 * jnp.mean(top_labels[:, 0] == labels)\n",
        "  top5 = 100 * jnp.sum(top_labels == labels[:, None]) / labels.shape[0]\n",
        "  return top1, top5, top_probs, top_labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7IRsBlmgN0nz"
      },
      "outputs": [],
      "source": [
        "dset = load_dataset('imagenet2012')\n",
        "embeddings, labels = compute_image_embeddings(dset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bV7jIm8sFjHD"
      },
      "outputs": [],
      "source": [
        "# Compute accuracy with class names and prompt engineering.\n",
        "# Computing accuracy on ImageNet validation set on a TPUv2 takes ~2 minutes.\n",
        "# Note: use np.matmul() get more accurate and consistent result and @\n",
        "logits_prompteng = np.matmul(embeddings, weights_prompteng)\n",
        "logits_name = np.matmul(embeddings, weights_name)\n",
        "# logits_prompteng = embeddings @ weights_prompteng\n",
        "# logits_name = embeddings @ weights_name\n",
        "top1_prompt, top5_prompt, top5_prompt_probs, top5_prompt_labels = compute_accuracy(logits_prompteng, labels) # top5_labels are the top-5 prediction\n",
        "top1_name, top5_name, top5_name_probs, top5_name_labels = compute_accuracy(logits_name, labels)\n",
        "print(f'Prompt Engineering: top1={top1_prompt:.2f}%, top5={top5_prompt:.2f}%')\n",
        "print(f'Class Names: top1={top1_name:.2f}%, top5={top5_name:.2f}%')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SLeqlvKV3ZpT"
      },
      "source": [
        "## 2. Build hierarchy from WordNet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xMc8VFmpzH8p"
      },
      "source": [
        "### WordNet parse"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GI8y4W79zH8p"
      },
      "outputs": [],
      "source": [
        "words_map = {}\n",
        "child_map = {}\n",
        "parent_map = {}\n",
        "gloss_map = {} # description\n",
        "\n",
        "blank = ' '\n",
        "comma_blank = ', '"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gSGZoG3CzH8p"
      },
      "outputs": [],
      "source": [
        "#obtain wordnet_id mappings for all words\n",
        "with tf.io.gfile.GFile(words_path, mode='r') as f:\n",
        "\tfor line in f:\n",
        "\t\tline_split = line.split() # use blank \" \" to split\n",
        "\t\twnid = line_split[0] # e.g., 'n03200357'\n",
        "\t\twords = line_split[1:] # e.g., ['electric,', 'electric', 'automobile,', 'electric', 'car'],\n",
        "\t\twords_map[wnid] = words\n",
        "f.close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KAeBPSSxzH8q"
      },
      "outputs": [],
      "source": [
        "#obtain wordnet_id mappings for all word description\n",
        "with tf.io.gfile.GFile(gloss_path, mode='r') as f:\n",
        "\tfor line in f:\n",
        "\t\tline_split = line.split()\n",
        "\t\twnid = line_split[0]\n",
        "\t\tgloss = blank.join(line_split[1:])\n",
        "\t\tgloss_map[wnid] = gloss\n",
        "f.close()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "awkQS7mnzH8q"
      },
      "outputs": [],
      "source": [
        "#obtain wordnet_id mappings for all parents-children\n",
        "with tf.io.gfile.GFile(child_map_path, mode='r') as f:\n",
        "\tfor line in f:\n",
        "\t\tparent, child = line.split()\n",
        "\t\tparent_map[child] = parent\n",
        "\t\tif parent not in child_map:\n",
        "\t\t\tchild_map[parent] = [child]\n",
        "\t\telse:\n",
        "\t\t\tchild_map[parent].append(child)\n",
        "f.close()\t"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zDv2CRlbzH8q"
      },
      "outputs": [],
      "source": [
        "# find details given wordnet_id\n",
        "category = 'n02690373' #'n02084071' is dog\n",
        "descendants = []\n",
        "ancestors = []\n",
        "# find Descendants and Ancestors\n",
        "print(words_map[category])\n",
        "print(gloss_map[category]+\"\\n\")\n",
        "#list all children\n",
        "print(\"Descendants:\\n\")\n",
        "if category in child_map:\n",
        "  search = [child for child in child_map[category]]\n",
        "while search: # go over all children (BFS)\n",
        "  node = search.pop()\n",
        "  print(\"\\t\"+ blank.join(words_map[node])+\"\\n\")\n",
        "  descendants.append(blank.join(words_map[node])) # keep all descendant\n",
        "  if node in child_map: #has children\n",
        "    [search.append(child) for child in child_map[node]]\n",
        "\n",
        "#list all parents\n",
        "print(\"Ancestors:\\n\")\n",
        "if category in parent_map:\n",
        "  node = parent_map[category] # only one parent class\n",
        "else:\n",
        "  node = category\n",
        "while node in parent_map: # one way go up\n",
        "  print(\"\\t\"+ blank.join(words_map[node])+\"\\n\")\n",
        "  ancestors.append(blank.join(words_map[node])) # keep all ancestor\n",
        "  node = parent_map[node]\n",
        "print(\"finish\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gReUKpVSzH8r"
      },
      "outputs": [],
      "source": [
        "# get imagenet_id: wordnet_id mapping. e.g., '0': 'n01440764'\n",
        "index_wdid = {}\n",
        "with tf.io.gfile.GFile(imagenet_label_to_wordnet_file, mode='r') as f:\n",
        "  for line in f:\n",
        "    if \"{'id'\" in line:\n",
        "      index_wdid[line.split(\": {'\")[0].split(\"{\")[-1].split(\" \")[-1]] = 'n' + line.split(\"-n\")[0].split(\"'\")[-1]      # {0: {'id': '01440764-n',"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wJ1xkt89qSSk"
      },
      "source": [
        "build hierarchy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SRMCK8HZzH8r"
      },
      "outputs": [],
      "source": [
        "# find ancestor, descendants and children for each class, may takes minites\n",
        "an_des_dict_json = {}\n",
        "for index in range(len(clip.IMAGENET_CLASSES)): # for 1000 classes\n",
        "  category = index_wdid[str(index)]\n",
        "  an_des_dict_json[str(index)] = {}\n",
        "  descendants = []\n",
        "  ancestors = []\n",
        "  children = []\n",
        "  # find Descendants and Ancestors\n",
        "  #list all children\n",
        "  if category in child_map:\n",
        "    search = [child for child in child_map[category]] # here is only the child\n",
        "    children = [blank.join(words_map[ele]) for ele in search]\n",
        "  while search: # go over all children BFS priority queue\n",
        "    node = search.pop()\n",
        "    descendants.append(blank.join(words_map[node])) # keep all descendant\n",
        "    if node in child_map: #has children\n",
        "      [search.append(child) for child in child_map[node]]\n",
        "\n",
        "  #list all parents\n",
        "  node = parent_map[category] if category in parent_map else category\n",
        "  while node in parent_map: # one way go up\n",
        "    ancestors.append(blank.join(words_map[node])) # keep all ancestor\n",
        "    node = parent_map[node]\n",
        "  \n",
        "  # save\n",
        "  an_des_dict_json[str(index)][\"wdid\"] = category\n",
        "  an_des_dict_json[str(index)][\"words_map\"] = blank.join(words_map[category])\n",
        "  an_des_dict_json[str(index)][\"clip_words_map\"] = clip.IMAGENET_CLASSES[index]\n",
        "  an_des_dict_json[str(index)][\"ancestors\"] = ancestors\n",
        "  an_des_dict_json[str(index)][\"descendants\"] = descendants\n",
        "  an_des_dict_json[str(index)][\"children\"] = children"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FJtim-_cJpVs"
      },
      "source": [
        "### Utility functions for finding descendants and ancestors"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FcHqCy_PAP3s"
      },
      "outputs": [],
      "source": [
        "# TODO(yunhaoge) Add examples for using the functions\n",
        "\n",
        "def bottom_up_hierarchy(max_depth_ancestor=0, max_depth_descendant=1, use_children=True):\n",
        "  \"\"\"Obtain the class list by considering hierarchy\n",
        "  \"\"\"\n",
        "  # Consider ancedescendant into clip zero-shot \n",
        "  imagenet_bottom_up_first_index_list = [] # start index of each imagenet class, used for following aggregation\n",
        "  imagenet_bottom_up_all_classes = []\n",
        "  imagenet_bottom_up_mapping_index_list = [] # index details of each imagenet class\n",
        "  i = 0\n",
        "  for imagenet_idx, classname in enumerate(clip.IMAGENET_CLASSES):\n",
        "    imagenet_bottom_up_first_index_list.append(i) \n",
        "    imagenet_bottom_up_all_classes.append(classname) # add original class name\n",
        "    class_an_des_list = [] # relevant id for specific class\n",
        "    i += 1 # add index first\n",
        "    class_an_des_list.append(i)\n",
        "    node_ancestors = an_des_dict_json[str(imagenet_idx)][\"ancestors\"]\n",
        "    if use_children: # consider only children\n",
        "      node_descendants = an_des_dict_json[str(imagenet_idx)][\"children\"]\n",
        "    else: \n",
        "      node_descendants = an_des_dict_json[str(imagenet_idx)][\"descendants\"]\n",
        "    \n",
        "    for an_idx, ancestor in enumerate(node_ancestors): # select ancestors\n",
        "      if an_idx \u003c max_depth_ancestor:\n",
        "        if comma_blank in ancestor: # contains synonym, more than one, keep only one\n",
        "          imagenet_bottom_up_all_classes.append(ancestor.split(comma_blank)[0])\n",
        "          i += 1\n",
        "          class_an_des_list.append(i)\n",
        "        else:\n",
        "          imagenet_bottom_up_all_classes.append(ancestor)\n",
        "          i += 1\n",
        "          class_an_des_list.append(i)\n",
        "    for des_idx, descentant in enumerate(node_descendants): # select descendants\n",
        "      if des_idx \u003c max_depth_descendant:\n",
        "        if comma_blank in descentant: # contains synonym, more than one, keep only one\n",
        "          imagenet_bottom_up_all_classes.append(descentant.split(comma_blank)[0])\n",
        "          i += 1\n",
        "          class_an_des_list.append(i)\n",
        "        else:\n",
        "          imagenet_bottom_up_all_classes.append(descentant)\n",
        "          i += 1\n",
        "          class_an_des_list.append(i)\n",
        "    imagenet_bottom_up_mapping_index_list.append(class_an_des_list)\n",
        "  return imagenet_bottom_up_all_classes, imagenet_bottom_up_first_index_list, imagenet_bottom_up_mapping_index_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RJxTyrVFt2_k"
      },
      "outputs": [],
      "source": [
        "def top_down_hierarchy(interest_case, clip_templete, use_LCA = True, use_ancestor = False):\n",
        "  \"\"\"Obtain the word list by adding LCA.\"\"\"\n",
        "  # interest_case\n",
        "  logits_ori_all = []\n",
        "  for idx, interest_case_ele in enumerate(tqdm(interest_case)):\n",
        "    # original word embedding\n",
        "    word_list_ori = np.array(clip.IMAGENET_CLASSES)[np.array(interest_case_ele)].tolist()\n",
        "    if use_LCA: # add LCA/A\n",
        "      if use_ancestor: # use Ancestor\n",
        "        word_list =  [word + blank + A_all[idx] for word in word_list_ori]\n",
        "      else: # use LCA\n",
        "        word_list =  [word + blank + LCA_all[idx] for word in word_list_ori]\n",
        "    else:\n",
        "      word_list =  [word for word in word_list_ori]\n",
        "  return word_list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Vnd06rNSpckx"
      },
      "outputs": [],
      "source": [
        "# fine the LCA for each image top 5\n",
        "def find_LCA(top5_ancestor): # general\n",
        "  \"\"\"Obtain the Lowest comman ancestor of top 5 classes.\n",
        "\n",
        "  Args:\n",
        "      top5_ancestor: List of the ancestors for each candidate class\n",
        "  Returns:\n",
        "      LCA: Str, lowest comman ancestor\n",
        "  \"\"\"\n",
        "  LCA = 'physical entity'\n",
        "  # while (1) still have ancestor\n",
        "  while min([len(ele) for ele in top5_ancestor]) \u003e 0 : \n",
        "    # find the highest for each class in topk\n",
        "    current_roots = [topk_ancestor.pop() for topk_ancestor in top5_ancestor]\n",
        "    current_roots_freq = Counter(current_roots)\n",
        "    current_roots_freq = sorted(current_roots_freq.items(), key=lambda x: x[1], reverse=True) # become a list e.g.,[('ee', 2), ('ww', 1), ('cc', 1)]\n",
        "    majority, majority_freq = current_roots_freq[0]\n",
        "    if majority_freq == 5:\n",
        "      LCA = majority if comma_blank not in majority else majority.split(comma_blank)[0]\n",
        "  return LCA"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uMul_5Icp7QL"
      },
      "source": [
        "## Uncertainty Estimation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CGJuTfRG-UWj"
      },
      "outputs": [],
      "source": [
        "# discrete confidence score calculation\n",
        "def discrete_uncertainty_estimation(prompts=clip.PROMPTS, name_prediction=top5_name_labels[:, 0], prompt_prediction=top5_prompt_labels[:, 0]):\n",
        "  \"\"\"Obtain discrete confidence score for each image\"\"\"\n",
        "\n",
        "  # first half\n",
        "  prompteng_firsthalf = zeroshot_classifier(clip.IMAGENET_CLASSES, prompts[0:int(len(prompts)/2)]) #first half [0:40]\n",
        "  logits_firsthalf = np.matmul(embeddings, prompteng_firsthalf) # compute the logits\n",
        "  probs_firsthalf = np.array(jax.nn.softmax(logits_firsthalf, axis=-1))\n",
        "  preds_firsthalf = jnp.argmax(probs_firsthalf, axis=-1)\n",
        "\n",
        "  # 2nd half\n",
        "  prompteng_2ndhalf = zeroshot_classifier(clip.IMAGENET_CLASSES, prompts[int(len(prompts)/2):]) #second half [40:80]\n",
        "  logits_2ndhalf = np.matmul(embeddings, prompteng_2ndhalf) # compute the logits\n",
        "  probs_2ndhalf = np.array(jax.nn.softmax(logits_2ndhalf, axis=-1))\n",
        "  preds_2ndhalf = jnp.argmax(probs_2ndhalf, axis=-1)\n",
        "  \n",
        "  # consider all four different kinds of prompts\n",
        "  prompts_log = np.stack((name_prediction, prompt_prediction, preds_firsthalf, preds_2ndhalf))\n",
        "\n",
        "  consistency_log = np.array([int(np.all(prompts_log[:, i] == prompts_log[:, i][0])) for i in tqdm(range(prompts_log.shape[-1]))])\n",
        "\n",
        "  unstable_id = np.where(consistency_log==0)\n",
        "\n",
        "  return consistency_log, unstable_id"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6_1-Kd1oGtEA"
      },
      "outputs": [],
      "source": [
        "# continuous confidence score calculation\n",
        "def continuous_uncertainty_estimation(prompts=clip.PROMPTS, name_prediction=top5_name_labels[:, 0]):\n",
        "  \"\"\"Obtain continuous confidence score for each image\"\"\"\n",
        "\n",
        "  # prompts decision\n",
        "  prompts_slice_log = []\n",
        "  for sample_id in range(len(prompts)): # for each different prompts\n",
        "    prompteng_slice = zeroshot_classifier(clip.IMAGENET_CLASSES, [prompts[sample_id]]) #select one each time\n",
        "    logits_slice = np.matmul(embeddings, prompteng_slice) # compute the logits\n",
        "    probs_slice = np.array(jax.nn.softmax(logits_slice, axis=-1))\n",
        "    preds_slice = jnp.argmax(probs_slice, axis=-1)\n",
        "    prompts_slice_log.append(preds_slice)\n",
        "\n",
        "  prompts_slice_log_array = np.array(prompts_slice_log) # (len(prompts), 50000)\n",
        "\n",
        "  # compare with no-prompts and use consistency to get the confidence score\n",
        "  consistency_with_nonpromp = prompts_slice_log_array == name_prediction\n",
        "  consistency_with_nonpromp_log = sum(consistency_with_nonpromp, axis=0)\n",
        "  confidence_sort_index = numpy.argsort(consistency_with_nonpromp_log) # form small to large\n",
        "  \n",
        "  return prompts_slice_log_array, consistency_with_nonpromp_log, confidence_sort_index"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vpu-BSCfVWNU"
      },
      "outputs": [],
      "source": [
        "def compute_accuracy_rerank(logits, ori_labels, labels):\n",
        "  \"\"\"Based on logits, reorder the ori_top5_labels and compare with true labels\"\"\"\n",
        "\n",
        "  top_probs, top_labels_index = jax.lax.top_k(logits, 1)\n",
        "  top_labels = np.array([ele[top_labels_index[idx]] for idx, ele in enumerate(tqdm(ori_labels))])\n",
        "  top1 = 100 * jnp.mean(top_labels[:, 0] == labels)\n",
        "  return top1, top_probs, top_labels"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "q81hCYl6qA84"
      },
      "outputs": [],
      "source": [
        "# Bottom-up and top-down augmented inference\n",
        "def hierarchy_zero_shot_inference(wordnet_an_des_dict, imagenet_bottom_up_all_classes, imagenet_bottom_up_mapping_index_list, confidence_sort_index, embeddings=embeddings, name_prediction=top5_name_labels[:, 0], unstable_num=10000, fill_empty_parent = ''):\n",
        "  \"\"\"Based on logits, reorder the ori_top5_labels and compare with true labels\n",
        "  Args:\n",
        "    wordnet_an_des_dict: wordnet hierarchy dict.\n",
        "    imagenet_bottom_up_all_classes: augmented imagenet class list after considering bottom up augmentation\n",
        "    imagenet_bottom_up_mapping_index_list: index list to help update logits for each top-5\n",
        "    confidence_sort_index: imagenet images id sorted by confidence score \n",
        "    embeddings: imagenet image embedding\n",
        "    name_prediction: predicted classes without prompts (anchor)\n",
        "    unstable_num: number of selected low confidence images\n",
        "    fill_empty_parent: for those classes that has no parent in the sparcified wordnet, fill nothing\n",
        "\n",
        "  Returns:\n",
        "    acc_stable: accuracy for selected high confidence images \n",
        "    unstable_top1_acc: accuracy for selected low confidence images \n",
        "    overall_acc: accuracy for all imagenet validation images \n",
        "  \"\"\"\n",
        "\n",
        "  # stable set and unstable set\n",
        "  unstable_index = confidence_sort_index[:unstable_num] # number of rejected/unstable images\n",
        "  stable_index = np.array(list(set(range(len(labels))) - set(unstable_index.tolist()))) # stable, np.array\n",
        "\n",
        "  # 1. accuracy of the stable samples:\n",
        "  # Manually computed accuracy\n",
        "  acc_stable = np.mean(name_prediction[stable_index] == labels[stable_index])\n",
        "  acc_stable_log.append(acc_stable)\n",
        "  print(acc_stable)\n",
        "\n",
        "  # 2. accuracy of the unstable samples: use our hierarchy\n",
        "  baseline_top5 = top5_name_labels[unstable_index]\n",
        "  baseline_top5_probs = top5_name_probs[unstable_index]\n",
        "  diff_image_embeddings = embeddings[unstable_index]\n",
        "  unstable_id_gt_labels = labels[unstable_index]\n",
        "\n",
        "  unstable_case = top5_name_labels[unstable_index]   # original top-5 prediction\n",
        "  unstable_case_label = labels[unstable_index] # top-5 GT labels\n",
        "\n",
        "  # [Top-5 Ancestor collection] calculate Ancestor from simple wordnet for unstable cases\n",
        "  A_all = []\n",
        "  for example in unstable_case: # for each unstable image\n",
        "    # ancestor of top 5\n",
        "    top5_ancestor = []\n",
        "    # top5_ancestor_raw = []\n",
        "    for ind, top_k in enumerate(example):\n",
        "      if wordnet_an_des_dict[str(example[ind])]['ancestors']==[]: # no ancestor\n",
        "        top5_ancestor.append(fill_empty_parent)\n",
        "      else:\n",
        "        top5_ancestor.append(wordnet_an_des_dict[str(example[ind])]['ancestors'].copy()[0]) # avoid influence the root\n",
        "    A_all.append(top5_ancestor)\n",
        "\n",
        "\n",
        "  # add ancestor to all candidate classes\n",
        "  logits_ori_all = []\n",
        "  for idx, unstable_case_ele in tqdm(enumerate(unstable_case)): # for each unstable image\n",
        "    # original word embedding\n",
        "    word_list_ori = np.array(clip.IMAGENET_CLASSES)[np.array(unstable_case_ele)].tolist()\n",
        "    word_list_hierarchy = [] # raw words with only [bottom up]\n",
        "    word_list_hierarchy_with_prompt = [] # words add prompt and ancestor [bottom up + top down]\n",
        "    wordid_list_hierarchy_list = [] # details of index [[1,2], [3,4,5], [6], [7,8], [9]]\n",
        "    wordid_list_hierarchy_mapping_list = [] # used for reduceat (index of first element for each top-5) [1,3,6,7,9]\n",
        "    i = 0 # tracking each class/child class\n",
        "    for word_id, word in enumerate(word_list_ori): # for each class in top 5\n",
        "      tempid_list = [] # for specific word in top5\n",
        "      wordid_list_hierarchy_mapping_list.append(i)\n",
        "      for child_id in imagenet_bottom_up_mapping_index_list[unstable_case_ele[word_id]]: # for each children or element [Bottom-up]\n",
        "        raw_word = imagenet_bottom_up_all_classes[child_id-1]\n",
        "        word_list_hierarchy.append(raw_word) # child_id start from 1 not 0\n",
        "        word_list_hierarchy_with_prompt.append(raw_word + ' which is a kind of' + A_all[idx][word_id]) # [Top-down]\n",
        "        tempid_list.append(i)\n",
        "        i+=1\n",
        "      wordid_list_hierarchy_list.append(tempid_list)\n",
        "    word_list = word_list_hierarchy_with_prompt # could change\n",
        "\n",
        "    # inference\n",
        "    word_weights_ori = zeroshot_classifier(word_list, clip.PROMPTS) # (512, 5)\n",
        "    img_embedding = diff_image_embeddings[idx] # (1, 512) single image\n",
        "    logits_ori_raw = np.matmul(img_embedding, word_weights_ori)\n",
        "    logits_ori_raw = logits_ori_raw[np.newaxis, :]\n",
        "    logits_ori = np.maximum.reduceat(logits_ori_raw, indices=wordid_list_hierarchy_mapping_list, axis=1)\n",
        "    logits_ori_all.append(logits_ori)\n",
        "  logits_ori_all_array = np.array(logits_ori_all)[:, 0, :]\n",
        "\n",
        "  # Manually computed accuracy\n",
        "  unstable_top1_acc, unstable_top_probs, unstable_top_labels = compute_accuracy_rerank(logits_ori_all_array, unstable_case, unstable_case_label)\n",
        "\n",
        "  # compute overall accuracy\n",
        "  overall_acc = (acc_stable*100*len(stable_index) + unstable_top1_acc*len(unstable_index))/len(labels)\n",
        "\n",
        "return acc_stable, unstable_top1_acc, overall_acc\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "zero-shot clip evaluation",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/zero-shot-multi-modal/zero_shot_clip_evaluation.ipynb",
          "timestamp": 1671824820712
        },
        {
          "file_id": "1pYzvrxFxq49N4pRjr_EqO5G88XcbymJ2",
          "timestamp": 1660953918805
        },
        {
          "file_id": "1GYA4lgUMzAggr0NI6YmpjxG-bEkVzqts",
          "timestamp": 1660839172150
        },
        {
          "file_id": "1kIMxAALvh_VjVlbb8R1jIr1_9jDxKwT1",
          "timestamp": 1660764437571
        },
        {
          "file_id": "1t3SaruLXW05sYXbktqrv02SEy7uUPtLm",
          "timestamp": 1657237417294
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
